from transformers import AutoTokenizer

import pickle
import csv
import json
import jsonlines
import pandas as pd
import numpy as np
from tqdm import tqdm
import random
import sys

# Pin the module-level RNG so the negatives drawn with random.sample further
# down are identical from run to run.
random.seed(a=42)


def pad_tokens(input_ids, block_size, tokenizer):
    """Truncate or right-pad *input_ids* to exactly ``block_size - 2`` ids.

    Two positions of ``block_size`` are excluded from the target length.
    NOTE(review): presumably reserved for special tokens added elsewhere —
    confirm against the caller.

    Args:
        input_ids: Sequence of token ids.
        block_size: Total block size; the result has
            ``max(block_size - 2, 0)`` elements.
        tokenizer: Any object exposing ``pad_token_id``, used as the fill value.

    Returns:
        A new list holding the first ``block_size - 2`` ids, right-padded with
        ``tokenizer.pad_token_id`` up to that length when *input_ids* is shorter.
    """
    # Clamp at 0: for block_size < 2 a negative slice bound would silently drop
    # ids from the *end* of the input instead of producing an empty result.
    target_len = max(block_size - 2, 0)
    # list(...) so tuple input can be padded too (list + list concatenation).
    truncated = list(input_ids[:target_len])
    return truncated + [tokenizer.pad_token_id] * (target_len - len(truncated))

# --- Script entry: tokenize a JSONL file of QA rows ---------------------------
# Usage: python <this script> INPUT_JSONL
# For every row of INPUT_JSONL, encodes "context", "question" and "answer"
# (plus negative examples) with the roberta-base tokenizer and writes the
# resulting token-id lists, one JSON object per line, to
# INPUT_JSONL + ".tensor.jsonl".
if len(sys.argv) < 2:
    # Fail with a usage message instead of a bare IndexError.
    sys.exit("usage: python <this script> INPUT_JSONL")

inpfile = sys.argv[1]
outfile = inpfile + ".tensor.jsonl"
tokenizer = AutoTokenizer.from_pretrained("roberta-base", use_fast=True)

# Pools for randomly drawn negatives (only used for rows without their own
# neg_* fields). Read inside `with` so the handles are closed promptly, and
# decode as UTF-8 explicitly rather than relying on the platform locale.
with open("all_kws.json", encoding="utf-8") as kws_file:
    kws = json.load(kws_file)
with open("all_sents.json", encoding="utf-8") as sents_file:
    sents = json.load(sents_file)

# Number of random negatives of each kind drawn for a row lacking neg_* fields.
NUM_RANDOM_NEGATIVES = 10


def _encode_all(texts):
    """Return ``tokenizer.encode`` applied to each string in *texts*, in order."""
    return [tokenizer.encode(text) for text in texts]


with jsonlines.open(inpfile) as ifd, jsonlines.open(outfile, "w") as ofd:
    # NOTE(review): these lengths are only printed; nothing below pads or
    # truncates to them (pad_tokens is never called). Confirm that is intended.
    clen = 25
    qlen = 25
    alen = 10

    print(f"{clen} {qlen} {alen}")

    for row in tqdm(ifd, "Writing TokenIds"):
        ctxt_tokens = tokenizer.encode(row["context"])
        qtxt_tokens = tokenizer.encode(row["question"])
        atxt_tokens = tokenizer.encode(row["answer"])
        if "neg_context" in row:
            # The row supplies its own negatives.
            neg_ctxts = _encode_all(row["neg_context"])
            neg_ques = _encode_all(row["neg_question"])
            neg_ans = _encode_all(row["neg_answer"])
        else:
            # Draw negatives at random: contexts and questions from the
            # sentence pool, answers from the keyword pool. The call order is
            # kept fixed so output stays reproducible under random.seed.
            # random.sample raises ValueError if a pool has fewer entries.
            neg_ctxts = _encode_all(random.sample(sents, NUM_RANDOM_NEGATIVES))
            neg_ques = _encode_all(random.sample(sents, NUM_RANDOM_NEGATIVES))
            neg_ans = _encode_all(random.sample(kws, NUM_RANDOM_NEGATIVES))

        ofd.write({
            "ctxt": ctxt_tokens,
            "qtxt": qtxt_tokens,
            "atxt": atxt_tokens,
            "nctxt": neg_ctxts,
            "nqtxt": neg_ques,
            "natxt": neg_ans,
        })